package sync

import (
	"context"
	"runtime"
	"sync"
	"sync/atomic"
)

/*
WaitPool 用来一次性批量处理任务
同时限制goroutine数量，防止一个任务一个goroutine
如果任务数小于允许的goroutine数，不推荐使用，而是直接使用WaitGroup
*/
type waitPoolTask func() error
type WaitPool struct {
	ctx     context.Context
	wg      sync.WaitGroup
	task    chan waitPoolTask
	once    sync.Once
	cap     int32 //最大开启goroutine数
	has     int32 //当前开启的goroutine数
	err     error //first error
	onceErr sync.Once
}

func NewWaitPool(ctx context.Context, maxGoroutine int) *WaitPool {
	if maxGoroutine < runtime.NumCPU() {
		maxGoroutine = runtime.NumCPU()
	}
	return &WaitPool{
		ctx:  ctx,
		task: make(chan waitPoolTask),
		cap:  int32(maxGoroutine),
	}
}

func (p *WaitPool) Go(f waitPoolTask) error {
	err := p.ctx.Err()
	if err != nil {
		return err
	}
	if p.err != nil {
		return p.err
	}
	if p.cap == 0 {
		p.wg.Add(1)
		go func() {
			defer p.wg.Done()
			if err := f(); err != nil {
				p.setError(err)
			}
		}()
		return p.err
	}
	select {
	case p.task <- f:
	default:
		has := atomic.AddInt32(&p.has, 1)
		if has <= p.cap {
			//goroutine未达上限，启动一个goroutine
			p.startGo(f)
		} else {
			//goroutine已达上限，等待一个goroutine执行完成
			atomic.StoreInt32(&p.has, p.cap)
			p.task <- f
		}
	}
	return p.err
}

func (p *WaitPool) Err() error {
	return p.err
}

// 调用Wait()后，等待任务执行完毕，不能再开启任务，退出goroutine
func (p *WaitPool) Wait() error {
	p.once.Do(func() {
		p.cap = 0
		close(p.task)
	})
	p.wg.Wait()
	return p.err
}

// 开启一个goroutine执行任务
func (p *WaitPool) startGo(f waitPoolTask) {
	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		for {
			if er := f(); er != nil {
				p.setError(er)
			}
			if task, ok := <-p.task; ok {
				f = task
			} else {
				break
			}
		}
	}()
}

func (p *WaitPool) setError(err error) {
	p.onceErr.Do(func() {
		p.err = err
	})
}
